# #####################################
# #                                   #
# #    NOT FULLY FUNCTIONAL CODE      #
# #                                   #
# #####################################
#
# This file provides snippets of the code used to create the EPD-ConvLSTM model
# described in the paper.  NOTE: this is not code that is used to train the
# model--this is simply the code needed to construct the model that would then
# be ready for training.  And even then, the code is not likely 100% complete.
# But these are exact code snippets taken from code used to generate the models
# in the paper.  Some knowledge of Keras and Tensorflow is likely necessary to
# understand these code snippets.
#
# The create_model function at the very end is what ultimately creates a model
# that is ready for training.  It expects to be passed in the input and label
# shapes, which theoretically, you'd have if you had a dataset prepped for
# training.  The shape for training data used in the paper was (batch, time,
# height, width, channels).  The shape for label data was the same, with
# channels = 1.
#
# Note that this code was copy/pasted out of multiple files from our source
# code into this file and this file was never explicitly run.  Bugs from that
# process may have been introduced.
#
# Note that the implementation of the LSTM assumes that each data point
# consists of 8 time points, with all of the recurrent state in the LSTM being
# reset for each data point.  This was sufficient for our analyses, but it is
# suboptimal if faster inference speed is required.  Instead of using the LSTM
# on 8 time steps at a time, it would be possible to feed in an initial 8 time
# steps, then save the internal state of the LSTM, and then perform future
# inferences on just a single time step at a time, where for each time step,
# the internal state of the LSTM is passed in.  This code does not reflect
# doing that (though it can be modified to do so fairly straightforwardly).

#
# Imports
#

import logging
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from ml_collections import config_dict
import tensorflow.compat.v2 as tf


#
# See create_epd_snippets.py.snippets file for the ChannelEncodingModel code.
# The EPD-ConvLSTM model uses the same encoding code the EPD model does.
#


#
# Code that adds the EPD-ConvLSTM layers to the network.
#
# Note that the EPD2DCONVLSTMModel is a subclass of the EPD2DModel, which is
# located in the create_epd_snippets.py.snippets file.
#


class ReshapeWithBatch(tf.keras.layers.Layer):
  """A Reshape layer that allows reshaping to affect the batch dimension.

  The stock tf.keras.layers.Reshape layer never touches the batch dimension,
  so it cannot merge another dimension into it.  E.g., assume the input to the
  layer has shape (None, 3, 4) and the desired output shape is (None, 4), i.e.
  the 2nd dimension is merged into the batch dimension (whose size is unknown
  at graph-creation time).  This cannot be accomplished with:

    tf.keras.layers.Reshape(target_shape=(-1, 4))

  That layer would not change the shape at all, as it would resolve the -1 to
  '3'.  We want the -1 to indicate that the batch dimension and the 2nd
  dimension are merged into a new, larger, batch dimension.  So, instead, use
  this layer like:

    ReshapeWithBatch(target_shape=(-1, 4))

  The target shape is handed to tf.reshape unchanged, so it describes the full
  output shape, including the batch dimension.
  """

  def __init__(self, target_shape: Iterable[int], **kwargs):
    """Initializes the layer.

    Args:
      target_shape: The full output shape, batch dimension included.  At most
        one entry may be -1, as required by tf.reshape.
      **kwargs: Forwarded to tf.keras.layers.Layer.
    """
    super().__init__(**kwargs)
    # Materialize the iterable once: a one-shot iterable (e.g. a generator)
    # would otherwise be exhausted by its first use, leaving nothing for later
    # call() or get_config() invocations.
    self.target_shape = tuple(target_shape)

  def get_config(self):
    """Returns the layer config, including the target shape."""
    config = super().get_config()
    config.update({'target_shape': self.target_shape})
    return config

  def call(self, inputs):
    """Reshapes `inputs`, batch dimension included, to the target shape."""
    return tf.reshape(inputs, shape=self.target_shape)


class EPD2DCONVLSTMModel(EPD2DModel):
  """An encode-processor-decoder (EPD) model with a ConvLSTM stage.

  The model takes a single input of shape (b, t, h, w, c).  The batch and time
  dimensions are merged so that the EPD stages process every time point as an
  independent image.  The two dimensions are split apart again only around the
  ConvLSTM2D layers, which sit between the processor and the decoder, and at
  the very end of the network.
  """

  # Number of time points per series assumed when the time dimension of the
  # input is not statically known.
  _DEFAULT_NUM_TIME_POINTS = 8

  def get_custom_objects(self) -> Dict[str, Any]:
    """Returns the custom Keras objects needed to deserialize this model."""
    custom_objects = {'ReshapeWithBatch': ReshapeWithBatch}
    custom_objects.update(EPD2DModel().get_custom_objects())
    return custom_objects

  @classmethod
  def get_default_config(cls) -> config_dict.ConfigDict:
    """Returns the default (unlocked) configuration for this model."""
    cfg = config_dict.ConfigDict()

    # The settings to pass to the epd_2d model.
    cfg.epd_2d = EPD2DModel.get_default_config()

    # The number of filters to use in the ConvLSTM.
    cfg.num_filters = 64

    # The size of the kernel to use in the ConvLSTM.
    cfg.kernel = 3

    # The activation to use in the ConvLSTM layer.
    cfg.activation = 'tanh'

    # The recurrent activation to use in the ConvLSTM layer.
    cfg.recurrent_activation = 'hard_sigmoid'

    # The number of sequential ConvLSTM layers to add.  A skip link is added
    # around each layer.
    cfg.blocks = 1

    # The amount of kernel or recurrent regularization to use.  Leave as 0.0 for
    # no regularization.  Within each group (reg_kernel_*, reg_rec_*, both_*)
    # set at most one value to non-zero: if several are set, l1 wins over l2,
    # which wins over l1_l2.  A non-zero both_* value overrides the
    # reg_kernel_* and reg_rec_* values.
    cfg.reg_kernel_l1 = 0.0
    cfg.reg_kernel_l2 = 0.0
    cfg.reg_kernel_l1_l2 = 0.0
    cfg.reg_rec_l1 = 0.0
    cfg.reg_rec_l2 = 0.0
    cfg.reg_rec_l1_l2 = 0.0
    cfg.both_l1 = 0.0
    cfg.both_l2 = 0.0
    cfg.both_l1_l2 = 0.0

    # The amount of dropout and recurrent dropout to use in the ConvLSTM2D
    # layers.  0.0 for no dropout.
    cfg.dropout = 0.0
    cfg.rec_dropout = 0.0

    return cfg

  def verify_input_shapes(
      self, inputs: Sequence[tf.Tensor],
      _: Optional[config_dict.ConfigDict]):
    """Checks that there is exactly one input and that it has rank 5.

    Args:
      inputs: The inputs to the model.
      _: The (unused) configuration.

    Raises:
      ValueError: If there is not exactly one input, or it is not rank 5.
    """
    if len(inputs) != 1:
      raise ValueError('There must be 1 input.')
    if len(inputs[0].shape) != 5:
      raise ValueError(
          'Expecting input[0] to have shape (b, t, h, w, c).')

  def _merge_batch_time(self, network: tf.Tensor, name: str) -> tf.Tensor:
    """Merges the first two dimensions of network into a single dimension."""
    network = ReshapeWithBatch(
        target_shape=(
            -1, network.shape[2], network.shape[3], network.shape[4]),
        name=name + '/reshape')(network)
    return network

  def _split_merged_batch_time(
      self,
      num_time_points_per_series: int,
      network: tf.Tensor,
      name: str) -> tf.Tensor:
    """Splits the first dimension into a batch and t dimension."""
    network = ReshapeWithBatch(
        target_shape=[-1, num_time_points_per_series, network.shape[1],
                      network.shape[2], network.shape[3]],
        name=name + '/reshape')(network)
    return network

  @staticmethod
  def _make_regularizer(
      l1: float,
      l2: float,
      l1_l2: float) -> Optional[tf.keras.regularizers.Regularizer]:
    """Returns the regularizer for the first non-zero strength, else None.

    Args:
      l1: Strength of an L1 regularizer.  Checked first.
      l2: Strength of an L2 regularizer.  Checked second.
      l1_l2: Strength used for both terms of an L1L2 regularizer.  Checked
        last.

    Returns:
      The selected regularizer, or None if no strength is greater than zero.
    """
    if l1 > 0.0:
      return tf.keras.regularizers.L1(l1)
    if l2 > 0.0:
      return tf.keras.regularizers.L2(l2)
    if l1_l2 > 0.0:
      return tf.keras.regularizers.L1L2(l1=l1_l2, l2=l1_l2)
    return None

  def _add_lstm_with_automatic_state(
      self,
      cfg: config_dict.ConfigDict,
      num_time_points_per_series: int,
      network: tf.Tensor,
      name: str) -> tf.Tensor:
    """Adds an LSTM to the network.

    The network is assumed to be of shape (X, h, w, c), where X is the batch
    channel that includes the time points of b time series of length t, so X has
    b*t entries in it, where each sequence of t batches actually belong to a
    time series.

    Args:
      cfg:  The configuration.
      num_time_points_per_series: The number of points in each time series.
      network: The network to add the lstm to.
      name: The base name to append graph names to.

    Returns:
      The network after the LSTM is added, again of shape (X, h, w, filters).

    Raises:
      ValueError: If network is not of rank 4.
    """
    # Verify this network is of the expected shape.
    if len(network.shape) != 4:
      raise ValueError('Expecting shape (X, h, w, c).')

    # Get the regularizers.
    kernel_regularizer = self._make_regularizer(
        cfg.reg_kernel_l1, cfg.reg_kernel_l2, cfg.reg_kernel_l1_l2)
    rec_regularizer = self._make_regularizer(
        cfg.reg_rec_l1, cfg.reg_rec_l2, cfg.reg_rec_l1_l2)

    # A non-zero `both_*` setting overrides the individual settings above.
    both = (cfg.both_l1, cfg.both_l2, cfg.both_l1_l2)
    if any(strength > 0.0 for strength in both):
      kernel_regularizer = self._make_regularizer(*both)
      rec_regularizer = self._make_regularizer(*both)

    logging.info('Using kernel_regularizer: %s', kernel_regularizer)
    logging.info('Using rec_regularizer: %s', rec_regularizer)
    logging.info('MODEL CREATE: Using dropout: %f', cfg.dropout)
    logging.info('MODEL CREATE: Using rec dropout: %f', cfg.rec_dropout)

    # Split the batch channel into a batch and time channel.
    network = self._split_merged_batch_time(
        num_time_points_per_series, network, name + '/split_batch')

    # Apply the ConvLSTM.
    for i in range(cfg.blocks):
      # Mark that the network is the beginning of a skip link.
      skip_start = network

      # Add in the ConvLSTM2d layer.  return_sequences=True keeps the time
      # dimension, so the output has one frame per input time point.
      network = tf.keras.layers.ConvLSTM2D(
          cfg.num_filters,
          cfg.kernel,
          padding='same',
          data_format='channels_last',
          activation=cfg.activation,
          recurrent_activation=cfg.recurrent_activation,
          return_sequences=True,
          return_state=False,
          kernel_regularizer=kernel_regularizer,
          recurrent_regularizer=rec_regularizer,
          dropout=cfg.dropout,
          recurrent_dropout=cfg.rec_dropout,
          name=name + f'/ConvLSTM2D_{i}')(network)

      # Create the skip link.  The 1x1x1 convolution projects the block input
      # to cfg.num_filters channels so it can be added to the ConvLSTM output.
      skip_link = tf.keras.layers.Conv3D(
          filters=cfg.num_filters,
          kernel_size=1,
          strides=1,
          padding='same',
          name=name + f'/SkipLink_{i}')(skip_start)

      # Merge the skip link with the output of the convolution.
      network = tf.keras.layers.Add(name=name + f'/MergeWithSkip_{i}')(
          [network, skip_link])

    # Reshape the data so that the time series are merged back into the batches.
    network = self._merge_batch_time(network, name + '/merge_batch')

    return network

  def get_network_with_automatic_state(
      self,
      inputs: Sequence[tf.Tensor],
      cfg: config_dict.ConfigDict,
      name: str) -> List[tf.Tensor]:
    """Adds the full EPD-ConvLSTM network on top of the input.

    Args:
      inputs: A sequence holding exactly one tensor of shape (b, t, h, w, c).
      cfg: The configuration, as returned by get_default_config().
      name: The base name to append graph names to.

    Returns:
      A single-element list holding the output tensor, with the batch and time
      dimensions split apart again.

    Raises:
      ValueError: If there is not exactly one input.
    """
    if len(inputs) != 1:
      raise ValueError('Expecting just a single input.')
    network = inputs[0]  # (b, t, h, w, c)

    # Use the actual length of the time dimension when it is statically known,
    # so that the later split regroups exactly the frames that were merged.
    # Fall back to the default when the time dimension is unknown.
    static_time = network.shape.as_list()[1]
    if static_time is None:
      num_time_points_per_series = self._DEFAULT_NUM_TIME_POINTS
    else:
      num_time_points_per_series = static_time

    # Merge the batch and time dimensions together, so all time points are
    # treated as separate batches.
    network = (self._merge_batch_time(
        network, name + '/merge_batch'))  # (b*t, h, w, c)

    # Build the EPD stages.
    network = self._add_input_processing(
        cfg.epd_2d, network, name + '/InputProcessing')
    network = self._add_encoder(cfg.epd_2d, network, name + '/Encoder')
    network = self._add_processor(cfg.epd_2d, network, name + '/Processor')
    network = self._add_lstm_with_automatic_state(
        cfg, num_time_points_per_series, network, name + '/ConvLSTM')
    network = self._add_decoder(cfg.epd_2d, network, name + '/Decoder')
    network = self._add_tail(cfg.epd_2d, network, name + '/Tail')

    # Split the merged batches back into batches of time series.
    network = self._split_merged_batch_time(
        num_time_points_per_series, network, name + '/split_batch')

    return [network]

  def create(
      self,
      inputs: Sequence[tf.Tensor],
      cfg: config_dict.ConfigDict,
      name: str = '') -> List[tf.Tensor]:
    """Creates the encoder-processor-decoder model.

    Args:
      inputs: A sequence holding exactly one tensor of shape (b, t, h, w, c).
      cfg: The configuration, as returned by get_default_config().
      name: The base name to append graph names to.

    Returns:
      A single-element list holding the output tensor of the network.
    """
    # Verify the inputs.
    self.verify_input_shapes(inputs, cfg)

    # Verify the epd_2d config is okay.
    self._verify_cfg(cfg.epd_2d)

    results = self.get_network_with_automatic_state(
        inputs, cfg, name=name + '/WithAutoStates')

    return results


#
# Code to create the EPD-ConvLSTM model used in the paper.
#

def get_config() -> config_dict.ConfigDict:
  """Builds the locked top-level configuration for the model."""
  # Default settings of the two sub-models.
  embedding_cfg = ChannelEncodingModel.get_default_config()
  lstm_cfg = EPD2DCONVLSTMModel.get_default_config()

  root = config_dict.ConfigDict()

  # If true, replace the categorical input with an embedding.
  root.use_embedding = False

  # The cfg for the embedding layer.  Not used if root.use_embedding is false.
  root.embedding = embedding_cfg

  # The configuration for the EPD-ConvLSTM model.
  root.lstm = lstm_cfg

  # Lock the config before handing it out.
  root.lock()
  return root


def get_custom_objects() -> Dict[str, Any]:
  """Collects the custom Keras objects of every sub-model used here.

  Entries from later models replace same-named entries from earlier ones.
  """
  merged: Dict[str, Any] = {}
  for model_cls in (ChannelEncodingModel, EPD2DCONVLSTMModel):
    merged.update(model_cls().get_custom_objects())
  return merged


def create_head(
    input_networks: List[tf.Tensor],
    cfg: config_dict.ConfigDict,
    expected_shape_wo_batch: Tuple[int, int, int, int],
) -> List[tf.Tensor]:
  """Builds the EPD-ConvLSTM network on top of the given input tensor.

  Args:
    input_networks: A list holding exactly one tensor of shape
      (b, t, h, w, c), with h == w.
    cfg: The config.  The fields `use_embedding`, `embedding` and `lstm` are
      read.
    expected_shape_wo_batch: The expected (t, h, w, c) shape of the input,
      without the batch dimension.  Any length-4 sequence is accepted.

  Returns:
    A single-element list holding the output tensor of the network.

  Raises:
    ValueError: If there is not exactly one input, if the expected shape does
      not have 4 entries, or if the input is not a square rank-5 tensor.
  """
  # Validate the number of inputs before indexing into the list, so an empty
  # list yields this error rather than an IndexError.
  if len(input_networks) != 1:
    raise ValueError('Expecting one input.')
  if len(expected_shape_wo_batch) != 4:
    raise ValueError('Expected shape should be (t, h, w, c), no batch.')

  network = input_networks[0]
  # tuple() lets lists and TensorShapes be concatenated with the batch entry.
  network = tf.ensure_shape(
      network, (None,) + tuple(expected_shape_wo_batch))
  if len(network.shape) != 5:
    raise ValueError('Input data should have the shape (b, t, h, w, c).')
  if network.shape[2] != network.shape[3]:
    raise ValueError('Expecting square input.')
  logging.info('The input network shape was: %s', network.shape)

  # The root name all layers added to this model will have.
  name = 'lstm'

  # Add the embedding layer, if needed.
  if cfg.use_embedding:
    channel_encoding_ml_model = ChannelEncodingModel()
    [network] = channel_encoding_ml_model.create([network],
                                                 cfg.embedding,
                                                 name=name + '/embedding')

  # Add the EPD-ConvLSTM network itself.
  model = EPD2DCONVLSTMModel()
  [network] = model.create(
      [network], cfg.lstm, name=name + '/epd2dconvlstm')
  return [network]


# Code to call to create the model.
def create_model(input_shape: tf.TensorShape,
                 label_shape: tf.TensorShape,
                 optimizer: Optional[Any] = None,
                 loss: Optional[Any] = None,
                 metrics: Optional[Any] = None,
                 run_eagerly: Optional[bool] = None) -> tf.keras.Model:
  """Builds the EPD-ConvLSTM Keras model and, optionally, compiles it.

  Args:
    input_shape: The shape of the input, including the batch dimension, i.e.
      (batch, time, height, width, channels).
    label_shape: The shape of the label.  Currently unused; accepted so the
      signature matches the shapes a prepared dataset provides.
    optimizer: The optimizer to compile the model with.  If None, the model is
      returned uncompiled and the caller has to compile it before training.
    loss: The loss function to compile the model with.
    metrics: The metrics to compile the model with.
    run_eagerly: Forwarded to tf.keras.Model.compile.

  Returns:
    The Keras model.
  """
  del label_shape  # Unused.

  # Drop the batch dimension and normalize to a plain tuple, so that
  # TensorShape, list and tuple inputs are all handled the same way.
  shape_wo_batch = tuple(tf.TensorShape(input_shape).as_list()[1:])

  # Create the input, ensuring the input has the correct shape.
  the_input = tf.keras.Input(shape=shape_wo_batch, name='input_image')
  input_layers = [the_input]

  output_layers = create_head(input_layers, get_config(), shape_wo_batch)

  keras_model = tf.keras.Model(input_layers, output_layers)

  # Only compile when the caller supplied an optimizer.
  if optimizer is not None:
    keras_model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=metrics,
        run_eagerly=run_eagerly)
  keras_model.summary()

  return keras_model
